{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d6e72215-501e-4bb0-ada9-7fb61849b524",
   "metadata": {},
   "source": [
    "# Jupyter执行LLaMA Demo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "62202b1e-ca0e-493e-8809-6e8deeb7c6bb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# Restrict this kernel to physical GPUs 6 and 7. CUDA_VISIBLE_DEVICES is only\n",
    "# honoured if it is set before CUDA is initialised, so export it before torch\n",
    "# (or any other library that may touch CUDA) is imported.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '6,7'\n",
    "\n",
    "import torch\n",
    "import transformers\n",
    "from peft import PeftModel\n",
    "from transformers import LlamaForCausalLM, LlamaConfig, LlamaTokenizer\n",
    "\n",
    "# Sanity check: number of GPUs visible to this process after masking.\n",
    "print(torch.cuda.device_count())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2c6eed29-07b0-49f5-a314-ea73313f1118",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/jiw203/wjn/InstructGraph/examples/demo\n"
     ]
    }
   ],
   "source": [
    "!pwd"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e4c24d7-f6ed-4c27-9bba-851564027054",
   "metadata": {},
   "source": [
    "### 加载LLaMA2与PEFT模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "a674b2ec-215e-44ee-879b-fbc3d7b063a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name_or_path = \"/home/jiw203/wjn/pre-trained-lm/Llama-2-7b-hf/\"\n",
    "peft_model = \"/home/jiw203/wjn/InstructGraph/output/instruction_tuning/fsdp_peft_flash_1500k/llama2-peft\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "449eaa15-5c6a-47e7-92d5-ebe8c9e86f3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_model(model_name, quantization):\n",
    "    \"\"\"Load the base LLaMA causal-LM used for text generation.\n",
    "\n",
    "    Args:\n",
    "        model_name: Checkpoint name or path handed to ``from_pretrained``.\n",
    "        quantization: Forwarded unchanged as ``load_in_8bit``.\n",
    "\n",
    "    Returns:\n",
    "        The ``LlamaForCausalLM`` instance returned by ``from_pretrained``.\n",
    "    \"\"\"\n",
    "    # Device placement is delegated to ``device_map=\"auto\"``.\n",
    "    load_kwargs = dict(\n",
    "        return_dict=True,\n",
    "        load_in_8bit=quantization,\n",
    "        device_map=\"auto\",\n",
    "        low_cpu_mem_usage=True,\n",
    "    )\n",
    "    return LlamaForCausalLM.from_pretrained(model_name, **load_kwargs)\n",
    "\n",
    "\n",
    "def load_peft_model(model, peft_model):\n",
    "    \"\"\"Wrap an already-loaded base model with a PEFT adapter.\n",
    "\n",
    "    Args:\n",
    "        model: Base model to wrap.\n",
    "        peft_model: Adapter name or path passed to ``PeftModel.from_pretrained``.\n",
    "\n",
    "    Returns:\n",
    "        The object returned by ``PeftModel.from_pretrained(model, peft_model)``.\n",
    "    \"\"\"\n",
    "    adapted_model = PeftModel.from_pretrained(model, peft_model)\n",
    "    return adapted_model\n",
    "\n",
    "def load_llama_from_config(config_path):\n",
    "    \"\"\"Instantiate a LLaMA causal-LM from a configuration alone.\n",
    "\n",
    "    Only the configuration is read from ``config_path``; the model is then\n",
    "    constructed directly from that config. The intended use is to load FSDP\n",
    "    checkpoints into the returned model afterwards.\n",
    "\n",
    "    Args:\n",
    "        config_path: Location passed to ``LlamaConfig.from_pretrained``.\n",
    "\n",
    "    Returns:\n",
    "        A ``LlamaForCausalLM`` built from the loaded configuration.\n",
    "    \"\"\"\n",
    "    return LlamaForCausalLM(config=LlamaConfig.from_pretrained(config_path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "8d51e532-8130-4eab-939b-2718b603c181",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c0dec3a04d594a5fa92d1b4ee075a64b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "model = load_model(model_name_or_path, quantization=False)\n",
    "model = load_peft_model(model, peft_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "26cd638f-a644-4831-9716-7fbe8209f405",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PeftModelForCausalLM(\n",
       "  (base_model): LoraModel(\n",
       "    (model): LlamaForCausalLM(\n",
       "      (model): LlamaModel(\n",
       "        (embed_tokens): Embedding(32000, 4096)\n",
       "        (layers): ModuleList(\n",
       "          (0-31): 32 x LlamaDecoderLayer(\n",
       "            (self_attn): LlamaAttention(\n",
       "              (q_proj): Linear(\n",
       "                in_features=4096, out_features=4096, bias=False\n",
       "                (lora_dropout): ModuleDict(\n",
       "                  (default): Dropout(p=0.05, inplace=False)\n",
       "                )\n",
       "                (lora_A): ModuleDict(\n",
       "                  (default): Linear(in_features=4096, out_features=32, bias=False)\n",
       "                )\n",
       "                (lora_B): ModuleDict(\n",
       "                  (default): Linear(in_features=32, out_features=4096, bias=False)\n",
       "                )\n",
       "                (lora_embedding_A): ParameterDict()\n",
       "                (lora_embedding_B): ParameterDict()\n",
       "              )\n",
       "              (k_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "              (v_proj): Linear(\n",
       "                in_features=4096, out_features=4096, bias=False\n",
       "                (lora_dropout): ModuleDict(\n",
       "                  (default): Dropout(p=0.05, inplace=False)\n",
       "                )\n",
       "                (lora_A): ModuleDict(\n",
       "                  (default): Linear(in_features=4096, out_features=32, bias=False)\n",
       "                )\n",
       "                (lora_B): ModuleDict(\n",
       "                  (default): Linear(in_features=32, out_features=4096, bias=False)\n",
       "                )\n",
       "                (lora_embedding_A): ParameterDict()\n",
       "                (lora_embedding_B): ParameterDict()\n",
       "              )\n",
       "              (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "              (rotary_emb): LlamaRotaryEmbedding()\n",
       "            )\n",
       "            (mlp): LlamaMLP(\n",
       "              (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
       "              (up_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
       "              (down_proj): Linear(in_features=11008, out_features=4096, bias=False)\n",
       "              (act_fn): SiLUActivation()\n",
       "            )\n",
       "            (input_layernorm): LlamaRMSNorm()\n",
       "            (post_attention_layernorm): LlamaRMSNorm()\n",
       "          )\n",
       "        )\n",
       "        (norm): LlamaRMSNorm()\n",
       "      )\n",
       "      (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "1a0396fe-d82a-4feb-93b8-acea79a6bd40",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = LlamaTokenizer.from_pretrained(model_name_or_path)\n",
    "tokenizer.pad_token = tokenizer.eos_token"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "efba0e8e-fe4a-44d6-96cd-66b6b089199f",
   "metadata": {},
   "source": [
    "### 自定义输入"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "ad2f2e40-635b-4aad-9bf7-9958942dcbef",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# instruction = (\n",
    "#     'You are a good graph reasoner. Give you a graph language that describes a graph structure and node information. You need to understand the graph and the task definition, and answer the question.\\nNote: (i <-> j) means that node i and node j are connected with an undirected edge. (i -> j) means that node i and node j are connected with a directed edge. \\n```\\nGraph[name=\"factual-graph\"] {\\n    entity_list = [\\'Raid on Unadilla and Onaquaga\\', \\'New York\\'];\\n    triple_list = [(\"Raid on Unadilla and Onaquaga\" -> \"New York\")[relation=\"located in the administrative territorial entity\"], (\"Raid on Unadilla and Onaquaga\" -> \"New York\")[relation=\"location\"]];\\n}\\n```\\nTask definition: given an event title and a factual knowledge graph, generate an event narration.\\nQ: Please generate an event narration based on the knowledge graph and the title \"Raid on Unadilla and Onaquaga\".\\nA:'\n",
    "# )\n",
    "instruction = \"\"\"You are a good graph generator. You need to understand the task definition and generate a graph language to answer the question. \n",
    "Note: (i <-> j) means that node i and node j are connected with an undirected edge. (i -> j) means that node i and node j are connected with a directed edge. \n",
    "Task definition: given a factual knowledge question, please think step by step: 1) find the topic entity and generate a corresponding knowledge subgraph to express the world knowledge information, 2) then generate a thinking explanation to describe this reasoning, and 3) finally output the final answer. Note that: 1) the graph is a directed-weighted graph and the name is \"factual-knowledge-graph\". 2) The generated graph language should be a code-like structure, and the skeleton format can be expressed as the following:\n",
    "```\n",
    "Graph[name=\"factual-knowledge-graph\"] {\n",
    "    entity_list = [\"xxx\", ...];\n",
    "    triple_list = [(\"xxx\" -> \"xxx\")[relation=\"xxx\"], ...];\n",
    "}\n",
    "```\n",
    "Q: What's the birth country of the film \"Avatar\"'s director?\n",
    "A:\"\"\"\n",
    "\n",
    "# instruction = \"What's your name?\"\n",
    "batch = tokenizer(instruction, return_tensors=\"pt\")\n",
    "# print(batch)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ea55b93-6449-40e5-8cea-249bc74f2a9c",
   "metadata": {},
   "source": [
    "### 模型生成"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "id": "1c0b5500-c373-466c-a309-c541bc5f0be9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "You are a good graph generator. You need to understand the task definition and generate a graph language to answer the question. \n",
      "Note: (i <-> j) means that node i and node j are connected with an undirected edge. (i -> j) means that node i and node j are connected with a directed edge. \n",
      "Task definition: given a factual knowledge question, please think step by step: 1) find the topic entity and generate a corresponding knowledge subgraph to express the world knowledge information, 2) then generate a thinking explanation to describe this reasoning, and 3) finally output the final answer. Note that: 1) the graph is a directed-weighted graph and the name is \"factual-knowledge-graph\". 2) The generated graph language should be a code-like structure, and the skeleton format can be expressed as the following:\n",
      "```\n",
      "Graph[name=\"factual-knowledge-graph\"] {\n",
      "    entity_list = [\"xxx\", ...];\n",
      "    triple_list = [(\"xxx\" -> \"xxx\")[relation=\"xxx\"], ...];\n",
      "}\n",
      "```\n",
      "Q: What's the birth country of the film \"Avatar\"'s director?\n",
      "A:To answer this question, we first find the topic entity is \"Avatar\".\n",
      "Then, we construct a knowledge subgraph of the topic entity, the graph language is:\n",
      "```\n",
      "Graph[name=\"freebase-knowledge-base\"] {\n",
      "    entity_list = ['James Cameron', 'Film director', 'Canada', 'Avatar'];\n",
      "    triple_list = [(\"Avatar\" -> \"Canada\")[relation=\"country of origin\"], (\"Avatar\" -> \"James Cameron\")[relation=\"director\"], (\"Film director\" -> \"Avatar\")[relation=\"Films directed\"]];\n",
      "}\n",
      "```\n",
      "Based on the graph, we can find a reasoning path ('Avatar', 'country of origin', 'Canada'), so the answer entity is \"Canada\".\n"
     ]
    }
   ],
   "source": [
    "# Move every tokenised input tensor onto the GPU.\n",
    "batch = {name: tensor.to(\"cuda\") for name, tensor in batch.items()}\n",
    "\n",
    "# Decoding options. Sampling is disabled (do_sample=False); the remaining knobs\n",
    "# are passed explicitly rather than left to the checkpoint's generation config.\n",
    "generation_kwargs = dict(\n",
    "    max_new_tokens=500,\n",
    "    do_sample=False,\n",
    "    top_p=1.0,\n",
    "    temperature=1.0,\n",
    "    min_length=None,\n",
    "    use_cache=True,\n",
    "    top_k=50,\n",
    "    repetition_penalty=1.0,\n",
    "    length_penalty=1,\n",
    ")\n",
    "\n",
    "# Inference only: no gradients are needed.\n",
    "with torch.no_grad():\n",
    "    outputs = model.generate(**batch, **generation_kwargs)\n",
    "\n",
    "# Decode the first returned sequence; it contains the prompt followed by the answer.\n",
    "output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
    "print(output_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c387e354-4d42-4066-9d25-99976fb99569",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "b23064c1-2722-4325-8a48-60cde1c4bfaf",
   "metadata": {},
   "source": [
    "## 模型生成能力初检"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a577541-d905-463a-afa3-a0522612c5bb",
   "metadata": {},
   "source": [
    "graph structure modeling\n",
    "- Connectivity：可以生成，指令理解、内容连贯度和形式都符合预期，但准确率需要优化\n",
    "- Cycle：可以生成，指令理解、内容连贯度和形式都符合预期，但准确率需要优化\n",
    "- Topological Sort：可以生成，指令理解、内容连贯度和形式都符合预期，但准确率需要优化\n",
    "- Shortest Path：可以生成，指令理解、内容连贯度和形式都符合预期，但准确率需要优化\n",
    "- Maximum Flow：可以生成，指令理解、内容连贯度和形式都符合预期，但准确率需要优化\n",
    "- Bipartite Graph Matching：可以生成，指令理解、内容连贯度和形式都符合预期\n",
    "- Hamilton Path：可以生成，指令理解、内容连贯度和形式都符合预期，但生成结果存在一定问题\n",
    "- Graph Degree：训练数据存在错误：计算degree但是标签却为yes/no\n",
    "\n",
    "graph caption generation:\n",
    "- wikipedia：可以生成，指令理解、内容连贯度和形式都符合预期\n",
    "- webnlg：可以生成，指令理解、内容连贯度和形式都符合预期\n",
    "- agenda：可以生成，指令理解、内容连贯度和形式都符合预期\n",
    "- genwiki：可以生成，指令理解、内容连贯度和形式都符合预期\n",
    "- EventNarrative：可以生成，指令理解、内容连贯度和形式都符合预期\n",
    "- XAlign：可以生成，指令理解、内容连贯度和形式都符合预期\n",
    "\n",
    "graph question answering：\n",
    "- pathquestions：可以生成，指令理解、内容连贯度和形式都符合预期\n",
    "- wc2014：可以生成，指令理解、内容连贯度和形式都符合预期\n",
    "- grailqa：可以生成，指令理解、内容连贯度和形式都符合预期\n",
    "- webquestion：可以生成，指令理解、内容连贯度和形式都符合预期\n",
    "- WikiTableQuestions：可以生成，指令理解、内容连贯度和形式都符合预期\n",
    "\n",
    "graph node cls：\n",
    "- cora：可以生成，指令理解、内容连贯度和形式都符合预期\n",
    "- citeseer：可以生成，指令理解、内容连贯度和形式都符合预期\n",
    "- pubmed：可以生成，指令理解、内容连贯度和形式都符合预期\n",
    "- ogbn-arxiv：可以生成，指令理解、内容连贯度和形式都符合预期\n",
    "- ogbn-product：可以生成，指令理解、内容连贯度和形式都符合预期\n",
    "\n",
    "graph link prediction\n",
    "- wikidata：可以生成，指令理解、内容连贯度和形式都符合预期\n",
    "- FB15k-237：可以生成，指令理解、内容连贯度和形式都符合预期\n",
    "- ConceptNet：可以生成，指令理解、内容连贯度和形式都符合预期\n",
    "\n",
    "Graph Relevance Inspection：可以生成，指令理解、内容连贯度和形式都符合预期\n",
    "\n",
    "Graph Collaboration Filtering：\n",
    "- movielens：可以生成，指令理解、内容连贯度和形式都符合预期\n",
    "- Amazon：可以生成，指令理解、内容连贯度和形式都符合预期\n",
    "- LastFM：可以生成，指令理解、内容连贯度和形式都符合预期\n",
    "\n",
    "knowledge graph generation：（调整一下顺序，将passage放在q的前面）\n",
    "- wikidata：可以生成，指令理解、内容连贯度和形式都符合预期\n",
    "- instructuie：可以生成，指令理解、内容连贯度和形式都符合预期\n",
    "- instructkgc：可以生成，指令理解、内容连贯度和形式都符合预期\n",
    "\n",
    "Structure Graph Generation：（调整一下顺序，将passage放在q的前面）\n",
    "可以生成，指令理解、内容连贯度和形式都符合预期\n",
    "\n",
    "Graph thought modeling: 没问题，但是graph language\n",
    "probing：没问题，但是graph language"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93cb4492-c34a-40d8-84bc-628025509731",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41565356-14fc-4325-8b9d-098c8690cd30",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "b84bbd59-8521-44c8-b25c-a0ffa5037f13",
   "metadata": {},
   "source": [
    "## 错误记录"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "e32d8a25-b27c-4c41-8666-fbda868b4ac7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\\nYou are a good graph reasoner. Give you a graph language that describe a graph structure and node infotmation. You need to understanding the graph and the task definition, and answer the question.\\nNote: (i <-> j) means that node i and node j are connected with an undirected edge. (i -> j) means that node i and node j are connected with an directed edge. \\n```\\nGraph[name=\"maximum-flow\"] {\\n    node_list = [0, 1, 2, 3, 4, 5];\\n    edge_list = [(0 -> 4)[capacity=3], (0 -> 1)[capacity=6], (0 -> 2)[capacity=3], (1 -> 2)[capacity=4], (1 -> 0)[capacity=4], (2 -> 4)[capacity=9], (2 -> 3)[capacity=7], (3 -> 4)[capacity=4], (3 -> 0)[capacity=5], (4 -> 1)[capacity=4], (4 -> 2)[capacity=4], (4 -> 0)[capacity=2], (5 -> 0)[capacity=3]];\\n}\\n```\\nTask definition: calculate the maximum flow between two nodes in the graph.\\nQ: What is the maximum flow from node 2 to node 4?\\nA:\\n\\n```\\nmax_flow_from_node_2_to_node_4 = 12\\n```\\n'"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\"\"\"\n",
    "You are a good graph reasoner. Give you a graph language that describe a graph structure and node infotmation. You need to understanding the graph and the task definition, and answer the question.\n",
    "Note: (i <-> j) means that node i and node j are connected with an undirected edge. (i -> j) means that node i and node j are connected with an directed edge. \n",
    "```\n",
    "Graph[name=\"maximum-flow\"] {\n",
    "    node_list = [0, 1, 2, 3, 4, 5];\n",
    "    edge_list = [(0 -> 4)[capacity=3], (0 -> 1)[capacity=6], (0 -> 2)[capacity=3], (1 -> 2)[capacity=4], (1 -> 0)[capacity=4], (2 -> 4)[capacity=9], (2 -> 3)[capacity=7], (3 -> 4)[capacity=4], (3 -> 0)[capacity=5], (4 -> 1)[capacity=4], (4 -> 2)[capacity=4], (4 -> 0)[capacity=2], (5 -> 0)[capacity=3]];\n",
    "}\n",
    "```\n",
    "Task definition: calculate the maximum flow between two nodes in the graph.\n",
    "Q: What is the maximum flow from node 2 to node 4?\n",
    "A:\n",
    "\n",
    "```\n",
    "max_flow_from_node_2_to_node_4 = 12\n",
    "```\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 155,
   "id": "f103e7e9-2bf3-4e4d-88d7-433caa19e6b9",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "You are a good graph generator. You need to understand the task definition and generate a graph language to answer the question. \n",
      "Note: (i <-> j) means that node i and node j are connected with an undirected edge. (i -> j) means that node i and node j are connected with a directed edge. \n",
      "Task definition: given a graph structure description, generate a corresponding graph language. Note that: 1) the graph is an undirected-weighted graph and the name is \"undirected-weighted-graph\". 2) The generated graph language should be a code-like structure, and the skeleton format can be expressed as the following:\n",
      "```\n",
      "Graph[name=\"undirected-weighted-graph\"] {\n",
      "    entity_list = [\"xxx\", ...];\n",
      "    triple_list = [(\"xxx\" <-> \"xxx\")[weight=\"xxx\"], ...];\n",
      "}\n",
      "```\n",
      "Q: Given you a graph structure description, please generate a corresponding graph.\n",
      "Description: \"In an undirected graph, the nodes are numbered from 0 to 4, and the edges are:\n",
      "an edge between node 0 and node 2 with weight 20,\n",
      "an edge between node 0 and node 3 with weight 1,\n",
      "an edge between node 1 and node 2 with weight 2,\n",
      "an edge between node 1 and node 4 with weight 2,\n",
      "an edge between node 1 and node 3 with weight 3,\n",
      "an edge between node 2 and node 4 with weight -3,\n",
      "an edge between node 2 and node 3 with weight 3.4,\n",
      "an edge between node 3 and node 4 with weight 1.\".\n",
      "A: \n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[1,\n",
       " 887,\n",
       " 526,\n",
       " 263,\n",
       " 1781,\n",
       " 3983,\n",
       " 15299,\n",
       " 29889,\n",
       " 887,\n",
       " 817,\n",
       " 304,\n",
       " 2274,\n",
       " 278,\n",
       " 3414,\n",
       " 5023,\n",
       " 322,\n",
       " 5706,\n",
       " 263,\n",
       " 3983,\n",
       " 4086,\n",
       " 304,\n",
       " 1234,\n",
       " 278,\n",
       " 1139,\n",
       " 29889,\n",
       " 29871,\n",
       " 13,\n",
       " 9842,\n",
       " 29901,\n",
       " 313,\n",
       " 29875,\n",
       " 529,\n",
       " 976,\n",
       " 432,\n",
       " 29897,\n",
       " 2794,\n",
       " 393,\n",
       " 2943,\n",
       " 474,\n",
       " 322,\n",
       " 2943,\n",
       " 432,\n",
       " 526,\n",
       " 6631,\n",
       " 411,\n",
       " 385,\n",
       " 563,\n",
       " 1088,\n",
       " 287,\n",
       " 7636,\n",
       " 29889,\n",
       " 313,\n",
       " 29875,\n",
       " 1599,\n",
       " 432,\n",
       " 29897,\n",
       " 2794,\n",
       " 393,\n",
       " 2943,\n",
       " 474,\n",
       " 322,\n",
       " 2943,\n",
       " 432,\n",
       " 526,\n",
       " 6631,\n",
       " 411,\n",
       " 263,\n",
       " 10624,\n",
       " 7636,\n",
       " 29889,\n",
       " 29871,\n",
       " 13,\n",
       " 5398,\n",
       " 5023,\n",
       " 29901,\n",
       " 2183,\n",
       " 263,\n",
       " 3983,\n",
       " 3829,\n",
       " 6139,\n",
       " 29892,\n",
       " 5706,\n",
       " 263,\n",
       " 6590,\n",
       " 3983,\n",
       " 4086,\n",
       " 29889,\n",
       " 3940,\n",
       " 393,\n",
       " 29901,\n",
       " 29871,\n",
       " 29896,\n",
       " 29897,\n",
       " 278,\n",
       " 3983,\n",
       " 338,\n",
       " 385,\n",
       " 563,\n",
       " 1088,\n",
       " 287,\n",
       " 29899,\n",
       " 7915,\n",
       " 287,\n",
       " 3983,\n",
       " 322,\n",
       " 278,\n",
       " 1024,\n",
       " 338,\n",
       " 376,\n",
       " 870,\n",
       " 1088,\n",
       " 287,\n",
       " 29899,\n",
       " 7915,\n",
       " 287,\n",
       " 29899,\n",
       " 4262,\n",
       " 1642,\n",
       " 29871,\n",
       " 29906,\n",
       " 29897,\n",
       " 450,\n",
       " 5759,\n",
       " 3983,\n",
       " 4086,\n",
       " 881,\n",
       " 367,\n",
       " 263,\n",
       " 775,\n",
       " 29899,\n",
       " 4561,\n",
       " 3829,\n",
       " 29892,\n",
       " 322,\n",
       " 278,\n",
       " 18109,\n",
       " 11285,\n",
       " 3402,\n",
       " 508,\n",
       " 367,\n",
       " 13384,\n",
       " 408,\n",
       " 278,\n",
       " 1494,\n",
       " 29901,\n",
       " 13,\n",
       " 28956,\n",
       " 13,\n",
       " 9527,\n",
       " 29961,\n",
       " 978,\n",
       " 543,\n",
       " 870,\n",
       " 1088,\n",
       " 287,\n",
       " 29899,\n",
       " 7915,\n",
       " 287,\n",
       " 29899,\n",
       " 4262,\n",
       " 3108,\n",
       " 426,\n",
       " 13,\n",
       " 1678,\n",
       " 7855,\n",
       " 29918,\n",
       " 1761,\n",
       " 353,\n",
       " 6796,\n",
       " 12353,\n",
       " 613,\n",
       " 2023,\n",
       " 1385,\n",
       " 13,\n",
       " 1678,\n",
       " 21954,\n",
       " 29918,\n",
       " 1761,\n",
       " 353,\n",
       " 518,\n",
       " 703,\n",
       " 12353,\n",
       " 29908,\n",
       " 529,\n",
       " 976,\n",
       " 376,\n",
       " 12353,\n",
       " 1159,\n",
       " 29961,\n",
       " 7915,\n",
       " 543,\n",
       " 12353,\n",
       " 12436,\n",
       " 2023,\n",
       " 1385,\n",
       " 13,\n",
       " 29913,\n",
       " 13,\n",
       " 28956,\n",
       " 13,\n",
       " 29984,\n",
       " 29901,\n",
       " 11221,\n",
       " 366,\n",
       " 263,\n",
       " 3983,\n",
       " 3829,\n",
       " 6139,\n",
       " 29892,\n",
       " 3113,\n",
       " 5706,\n",
       " 263,\n",
       " 6590,\n",
       " 3983,\n",
       " 29889,\n",
       " 13,\n",
       " 9868,\n",
       " 29901,\n",
       " 376,\n",
       " 797,\n",
       " 385,\n",
       " 563,\n",
       " 1088,\n",
       " 287,\n",
       " 3983,\n",
       " 29892,\n",
       " 278,\n",
       " 7573,\n",
       " 526,\n",
       " 1353,\n",
       " 287,\n",
       " 515,\n",
       " 29871,\n",
       " 29900,\n",
       " 304,\n",
       " 29871,\n",
       " 29946,\n",
       " 29892,\n",
       " 322,\n",
       " 278,\n",
       " 12770,\n",
       " 526,\n",
       " 29901,\n",
       " 13,\n",
       " 273,\n",
       " 7636,\n",
       " 1546,\n",
       " 2943,\n",
       " 29871,\n",
       " 29900,\n",
       " 322,\n",
       " 2943,\n",
       " 29871,\n",
       " 29906,\n",
       " 411,\n",
       " 7688,\n",
       " 29871,\n",
       " 29906,\n",
       " 29900,\n",
       " 29892,\n",
       " 13,\n",
       " 273,\n",
       " 7636,\n",
       " 1546,\n",
       " 2943,\n",
       " 29871,\n",
       " 29900,\n",
       " 322,\n",
       " 2943,\n",
       " 29871,\n",
       " 29941,\n",
       " 411,\n",
       " 7688,\n",
       " 29871,\n",
       " 29896,\n",
       " 29892,\n",
       " 13,\n",
       " 273,\n",
       " 7636,\n",
       " 1546,\n",
       " 2943,\n",
       " 29871,\n",
       " 29896,\n",
       " 322,\n",
       " 2943,\n",
       " 29871,\n",
       " 29906,\n",
       " 411,\n",
       " 7688,\n",
       " 29871,\n",
       " 29906,\n",
       " 29892,\n",
       " 13,\n",
       " 273,\n",
       " 7636,\n",
       " 1546,\n",
       " 2943,\n",
       " 29871,\n",
       " 29896,\n",
       " 322,\n",
       " 2943,\n",
       " 29871,\n",
       " 29946,\n",
       " 411,\n",
       " 7688,\n",
       " 29871,\n",
       " 29906,\n",
       " 29892,\n",
       " 13,\n",
       " 273,\n",
       " 7636,\n",
       " 1546,\n",
       " 2943,\n",
       " 29871,\n",
       " 29896,\n",
       " 322,\n",
       " 2943,\n",
       " 29871,\n",
       " 29941,\n",
       " 411,\n",
       " 7688,\n",
       " 29871,\n",
       " 29941,\n",
       " 29892,\n",
       " 13,\n",
       " 273,\n",
       " 7636,\n",
       " 1546,\n",
       " 2943,\n",
       " 29871,\n",
       " 29906,\n",
       " 322,\n",
       " 2943,\n",
       " 29871,\n",
       " 29946,\n",
       " 411,\n",
       " 7688,\n",
       " 448,\n",
       " 29941,\n",
       " 29892,\n",
       " 13,\n",
       " 273,\n",
       " 7636,\n",
       " 1546,\n",
       " 2943,\n",
       " 29871,\n",
       " 29906,\n",
       " 322,\n",
       " 2943,\n",
       " 29871,\n",
       " 29941,\n",
       " 411,\n",
       " 7688,\n",
       " 29871,\n",
       " 29941,\n",
       " 29889,\n",
       " 29946,\n",
       " 29892,\n",
       " 13,\n",
       " 273,\n",
       " 7636,\n",
       " 1546,\n",
       " 2943,\n",
       " 29871,\n",
       " 29941,\n",
       " 322,\n",
       " 2943,\n",
       " 29871,\n",
       " 29946,\n",
       " 411,\n",
       " 7688,\n",
       " 29871,\n",
       " 29896,\n",
       " 1213,\n",
       " 29889,\n",
       " 13,\n",
       " 29909,\n",
       " 29901,\n",
       " 29871]"
      ]
     },
     "execution_count": 155,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(instruction)\n",
    "a = list(tokenizer.encode(instruction))\n",
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "522c76ea-bbf0-475b-ade7-ceef8674233d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['<s>', '▁A', ':', 'The', '▁thought', '▁graph', '▁that', '▁express', 'es']"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.convert_ids_to_tokens([1, 319, 29901, 1576, 2714, 3983, 393, 4653, 267])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "a2d1e50c-70a2-4eb1-813f-1c010752e544",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.convert_tokens_to_ids(\"</s>\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 160,
   "id": "689718b6-02b4-43bf-9431-d2adad2995b4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": 160,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.eos_token_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e9d605e-8029-44c7-bb66-0a245fd23d43",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "1a75da00-cf2f-4346-b988-899e51fbc68d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1, 319, 29901, 1576, 2714, 3983, 393, 4653, 267, 278, 24481, 10757, 338, 7521, 13, 9527, 29961, 978, 543, 23147, 292, 29899, 386, 1774, 29899, 4262, 3108, 426, 13, 1678, 2943, 29918, 1761, 353, 6796, 25719, 613, 376, 8835, 4982, 613, 376, 1287, 613, 376, 9057, 613, 376, 8835, 10370, 13, 1678, 7636, 29918, 1761, 353, 518, 703, 25719, 29908, 1599, 376, 8835, 4982, 29908, 29961, 5030, 5946, 543, 29874, 326, 3108, 511, 4852, 1287, 29908, 1599, 376, 9057, 29908, 29961, 5030, 5946, 543, 6689, 310, 2125, 2058, 3108, 511, 4852, 9057, 29908, 1599, 376, 8835, 29908, 29961, 5030, 5946, 543, 28111, 20068, 1385, 13, 29913, 13, 16159, 1412, 29871, 13, 1576, 1234, 881, 367, 278, 7306, 310, 2305, 472, 664, 29889, 4587, 278, 2038, 19995, 29892, 278, 7306, 310, 2305, 472, 664, 338, 304, 4866, 278, 4982, 29889, 29871, 13, 6295, 278, 1234, 338, 319, 29889, 29871, 2]\n"
     ]
    }
   ],
   "source": [
    "import copy\n",
    "IGNORE_INDEX = -100\n",
    "instruction = \"\"\"A:\"\"\"\n",
    "\n",
    "answer = \"\"\"The thought graph that expresses the reasoning evidence is ```\\nGraph[name=\"reasoning-thought-graph\"] {\\n    node_list = [\"people\", \"complete job\", \"work\", \"job\", \"complete\"];\\n    edge_list = [(\"people\" -> \"complete job\"[capacity=\"aim\"]), (\"work\" -> \"job\"[capacity=\"place of take place\"]), (\"job\" -> \"complete\"[capacity=\"goal\"])];\\n}\\n```. \\nThe answer should be the goal of people at work. Of the above choices, the goal of people at work is to complete the job. \\nSo the answer is A. \"\"\"\n",
    "\n",
    "\n",
    "# Tokenise the prompt on its own and the prompt followed by the answer.\n",
    "example = tokenizer.encode(instruction + answer)  # list of token ids\n",
    "prompt = tokenizer.encode(instruction)\n",
    "\n",
    "# Keep at most 2047 tokens so that the EOS appended below fits within 2048.\n",
    "example = example[:2047]\n",
    "example.append(tokenizer.eos_token_id)\n",
    "print(example)\n",
    "\n",
    "prompt = torch.tensor(prompt, dtype=torch.int64)\n",
    "example = torch.tensor(example, dtype=torch.int64)\n",
    "\n",
    "# Labels start as a copy of the inputs; the prompt positions are flagged with -1\n",
    "# and every negative entry is then rewritten (0 for inputs, IGNORE_INDEX for labels).\n",
    "labels = copy.deepcopy(example)\n",
    "labels[: len(prompt)] = -1\n",
    "example_mask = example >= 0\n",
    "label_mask = labels >= 0\n",
    "example[~example_mask] = 0\n",
    "labels[~label_mask] = IGNORE_INDEX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "fdce3f03-595d-4b04-b7d5-0fbdac2cc37d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([ -100,  -100,  -100,  1576,  2714,  3983,   393,  4653,   267,   278,\n",
      "        24481, 10757,   338,  7521,    13,  9527, 29961,   978,   543, 23147,\n",
      "          292, 29899,   386,  1774, 29899,  4262,  3108,   426,    13,  1678,\n",
      "         2943, 29918,  1761,   353,  6796, 25719,   613,   376,  8835,  4982,\n",
      "          613,   376,  1287,   613,   376,  9057,   613,   376,  8835, 10370,\n",
      "           13,  1678,  7636, 29918,  1761,   353,   518,   703, 25719, 29908,\n",
      "         1599,   376,  8835,  4982, 29908, 29961,  5030,  5946,   543, 29874,\n",
      "          326,  3108,   511,  4852,  1287, 29908,  1599,   376,  9057, 29908,\n",
      "        29961,  5030,  5946,   543,  6689,   310,  2125,  2058,  3108,   511,\n",
      "         4852,  9057, 29908,  1599,   376,  8835, 29908, 29961,  5030,  5946,\n",
      "          543, 28111, 20068,  1385,    13, 29913,    13, 16159,  1412, 29871,\n",
      "           13,  1576,  1234,   881,   367,   278,  7306,   310,  2305,   472,\n",
      "          664, 29889,  4587,   278,  2038, 19995, 29892,   278,  7306,   310,\n",
      "         2305,   472,   664,   338,   304,  4866,   278,  4982, 29889, 29871,\n",
      "           13,  6295,   278,  1234,   338,   319, 29889, 29871,     2])\n"
     ]
    }
   ],
   "source": [
    "print(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a33d04b0-6e40-4600-8e65-fc0caaa04986",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[1, 2, 3],\n",
       " [1, -2, 3],\n",
       " [4, 4, 5],\n",
       " [1, 2, 3],\n",
       " [1, -2, 3],\n",
       " [4, 4, 5],\n",
       " [1, 2, 3],\n",
       " [1, -2, 3],\n",
       " [4, 4, 5]]"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "[[1, 2, 3], [1, -2, 3], [4, 4, 5]]*3"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
